package ao.unsupervised.cluster.analysis;

import ao.unsupervised.cluster.space.Domain;
import ao.util.math.rand.MersenneTwisterFast;
import ao.util.math.rand.Rand;

/**
 * User: alex
 * Date: 6-Jun-2009
 * Time: 11:02:38 AM
 *
 *    See http://en.wikipedia.org/wiki/K-means
 */
public class KMeans<T> implements ClusterAnalysis<T>
{
    //--------------------------------------------------------------------
//    private final Long forcedSeed;


    //--------------------------------------------------------------------
    public KMeans()
    {
//        forcedSeed = null;
    }
//    public KMeans(long seed)
//    {
//        forcedSeed = seed;
//    }


    //--------------------------------------------------------------------
    public int[] cluster(Domain<T> items, int nClusters)
    {
        return cluster(initMeans(items, nClusters), items);
    }


    //--------------------------------------------------------------------
    private int[] cluster(
            T         means[],
            Domain<T> items)
    {
        int clusters[] = new int[ items.locationCount() ];

        while (assignmentStep(means, items, clusters)) {
//            displayClustering(clusters, items);
            updateStep(means, items, clusters);
        }

        return clusters;
    }


    //--------------------------------------------------------------------
//    private void displayClustering(
//            byte      clusters[],
//            Domain<T> items)
//    {
//        for (int j = 6; j <= 66; j++) {
//            System.out.print("\t" + j);
//        }
//        System.out.println();
//
//        for (int i = 0; i <= 50; i++) {
//            System.out.print(i + "\t");
//
//            for (int j = 6; j <= 66; j++) {
//                CentroidDomain<Centroid<double[]>, double[]>
//                        vectorDomain =
//                        (CentroidDomain<Centroid<double[]>, double[]>)
//                                items;
//
//                boolean seen = false;
//                for (int k = 0; k < vectorDomain.size(); k++) {
//                    double val[] = vectorDomain.get(k);
//
//                    if (i == (int) Math.round(val[0] * 100) &&
//                            j == (int) Math.round(val[1] * 100)) {
//                        System.out.print(Integer.toString(
//                                clusters[k], 16));
//                    }
//                }
//
//                if (! seen) {
//                    System.out.print("\t");
//                }
//            }
//            System.out.println();
//        }
//    }


    //--------------------------------------------------------------------
    // Assign each observation to the cluster with the closest mean
    //  (i.e. partition the observations according to the
    //          Voronoi diagram generated by the means).
    private boolean assignmentStep(
            T         means[],
            Domain<T> items,
            int       clusters[])
    {
        boolean changeMade = false;
        for (int i = 0; i < items.locationCount(); i++) {
//            double strength = strengths.strengthNorm(i);

            int    leastDistIndex = -1;
            double leastDistance  = Double.POSITIVE_INFINITY;
            for (int j = 0; j < means.length; j++) {

                double distance = items.distanceBetween(means[j], i);
                if (leastDistance  > distance) {
                    leastDistance  = distance;
                    leastDistIndex = j;
                }
            }

            changeMade    |= (clusters[ i ] != leastDistIndex);
            clusters[ i ]  = leastDistIndex;
        }
        return changeMade;
    }


    //--------------------------------------------------------------------
    // Calculate the new means to be the centroid of the
    //   observations in the cluster.
    private void updateStep(
            T         means[],
            Domain<T> items,
            int       clusters[])
    {
        for (int i = 0; i < means.length; i++)
        {
            means[i] = items.newCentroid();
        }

        for (int i = 0; i < clusters.length; i++)
        {
            items.mergeAll(means[ clusters[i] ], i);
        }
    }


    //--------------------------------------------------------------------
    @SuppressWarnings("unchecked")
    private T[] initMeans(Domain<T> items, int nClusters) {
        MersenneTwisterFast rand =
                new MersenneTwisterFast(
                        seed(items, nClusters));

        int means[] = new int[ nClusters ];

        // Choose one center uniformly at random
        //  from among the data points.
        means[0] = rand.nextInt(items.locationCount());

        for (int k = 1; k < nClusters; k++)
        {
            // For each data point x, compute D(x), the distance between
            //   x and the nearest center that has already been chosen.

            double maxChance      = Double.NEGATIVE_INFINITY;
            int    maxChanceIndex = -1;

            for (int i = 0; i < items.locationCount(); i++) {

                double nearestCluster = Double.POSITIVE_INFINITY;
                for (int j = 0; j < k; j++) {
                    // for each previously established cluster
                    double dist = items.distanceBetween(i, means[j]);
                    if (nearestCluster > dist) {
                        nearestCluster = dist;
                    }
                }

                // Each point x is chosen with
                //  probability proportional to D(x)^2.
                double chance = rand.nextDouble()
                        * nearestCluster * nearestCluster
                        * items.represents(i);
                if (maxChance < chance) {
                    maxChance      = chance;
                    maxChanceIndex = i;
                }
            }

            // Add one new data point as a center.
            means[k] = maxChanceIndex;
        }

        T meanVals[] = (T[]) new Object[ means.length ];
        for (int i = 0; i < means.length; i++) {
            meanVals[ i ] = items.newCentroid();
            items.mergeAll(meanVals[i], means[i]);
        }
        return meanVals;
    }

    private long seed(Domain items, int nClusters) {
        return Rand.nextLong();
//        return forcedSeed != null
//               ? forcedSeed
//               :
////                items.locationCount() * nClusters
//                Rand.nextLong()
//                ;
    }
}
